package sqlx

import (
	"context"
	"database/sql"
	"errors"

	"gitee.com/yanwc/gozero-utils/errx"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
)

// ExecSql is the context-free form of ExecSqlCtx: it runs the statement built
// by pre under context.Background() and forwards the result to done.
func ExecSql(session sqlx.Session, pre func() (string, []interface{}, error), done func(result sql.Result) error, errLabel string) error {
	ctx := context.Background()
	return ExecSqlCtx(ctx, session, pre, done, errLabel)
}

// ExecSqlCtx builds a statement with pre, executes it on session under ctx and,
// when done is non-nil, hands the resulting sql.Result to done.
//
// A failure from pre is wrapped as errx.DB_QUERY and a failure from the
// execution itself as errx.DB; both carry errLabel as the message. An error
// returned by done is passed back to the caller unchanged.
func ExecSqlCtx(ctx context.Context, session sqlx.Session, pre func() (string, []interface{}, error), done func(result sql.Result) error, errLabel string) error {
	stmt, params, buildErr := pre()
	if buildErr != nil {
		return errx.New(errx.DB_QUERY, errx.WithMsgOption(errLabel), errx.WithErrorOption(buildErr))
	}

	res, execErr := session.ExecCtx(ctx, stmt, params...)
	if execErr != nil {
		return errx.New(errx.DB, errx.WithMsgOption(errLabel), errx.WithErrorOption(execErr))
	}

	// The result callback is optional.
	if done == nil {
		return nil
	}
	return done(res)
}

// ExecInsertSql executes the statement built by pre under context.Background()
// and passes the driver-reported last insert id to done.
//
// Errors from pre and from execution are wrapped as in ExecSqlCtx. A failure to
// read LastInsertId is wrapped as errx.DB labelled with errLabel, matching how
// execution errors are reported. An error returned by done is passed through
// unchanged.
//
// done may be nil. In that case the statement is still executed, but
// LastInsertId is not requested, because some drivers do not support it.
func ExecInsertSql(session sqlx.Session, pre func() (string, []interface{}, error), done func(lastInsertId int64) error, errLabel string) error {
	if done == nil {
		// Nothing wants the id; avoid both a nil-func panic and an unnecessary
		// LastInsertId call.
		return ExecSqlCtx(context.Background(), session, pre, nil, errLabel)
	}

	return ExecSqlCtx(context.Background(), session, pre, func(result sql.Result) error {
		lastInsertId, err := result.LastInsertId()
		if err != nil {
			return errx.New(errx.DB, errx.WithMsgOption(errLabel), errx.WithErrorOption(err))
		}
		return done(lastInsertId)
	}, errLabel)
}

// queryRowCtx builds a query with pre and scans a single row into data via
// session.QueryRowCtx.
//
// Failures from pre and from the query are wrapped as errx.DB_QUERY with
// errLabel. When the query reports sqlx.ErrNotFound, the outcome depends on
// must: must=true yields an errx.DB_NOT_FOUND error, must=false yields nil.
func queryRowCtx(ctx context.Context, session sqlx.Session, pre func() (string, []interface{}, error), data interface{}, errLabel string, must bool) error {
	sqlText, params, buildErr := pre()
	if buildErr != nil {
		return errx.New(errx.DB_QUERY, errx.WithMsgOption(errLabel), errx.WithErrorOption(buildErr))
	}

	scanErr := session.QueryRowCtx(ctx, data, sqlText, params...)
	switch {
	case scanErr == nil:
		return nil
	case errors.Is(scanErr, sqlx.ErrNotFound):
		// A missing row is only an error when the caller demands one.
		if !must {
			return nil
		}
		return errx.New(errx.DB_NOT_FOUND, errx.WithMsgOption(errLabel))
	default:
		return errx.New(errx.DB_QUERY, errx.WithMsgOption(errLabel), errx.WithErrorOption(scanErr))
	}
}

// QueryRowCtx scans the single row selected by the query from pre into data.
// If no row matches, it returns nil rather than an error; use QueryRowCtxMust
// to get an errx.DB_NOT_FOUND error instead. Other failures are wrapped as
// errx.DB_QUERY with errLabel.
func QueryRowCtx(ctx context.Context, session sqlx.Session, pre func() (string, []interface{}, error), data interface{}, errLabel string) error {
	const mustExist = false
	return queryRowCtx(ctx, session, pre, data, errLabel, mustExist)
}

// QueryRowCtxMust is like QueryRowCtx, but it reports a missing row as an
// errx.DB_NOT_FOUND error labelled with errLabel.
func QueryRowCtxMust(ctx context.Context, session sqlx.Session, pre func() (string, []interface{}, error), data interface{}, errLabel string) error {
	const mustExist = true
	return queryRowCtx(ctx, session, pre, data, errLabel, mustExist)
}

// QueryRow is QueryRowCtx run under context.Background(). A missing row
// yields nil.
func QueryRow(session sqlx.Session, pre func() (string, []interface{}, error), data interface{}, errLabel string) error {
	return QueryRowCtx(context.Background(), session, pre, data, errLabel)
}

// QueryRowMust is QueryRowCtxMust run under context.Background(). A missing
// row yields an errx.DB_NOT_FOUND error.
func QueryRowMust(session sqlx.Session, pre func() (string, []interface{}, error), data interface{}, errLabel string) error {
	return QueryRowCtxMust(context.Background(), session, pre, data, errLabel)
}

// QueryRowsCtx builds a query with pre and scans every matching row into data
// via session.QueryRowsCtx. data is presumably a pointer to a slice; confirm
// this against callers.
//
// Failures from pre and from the query are wrapped as errx.DB_QUERY with
// errLabel. sqlx.ErrNotFound is reported as errx.DB_NOT_FOUND.
func QueryRowsCtx(ctx context.Context, session sqlx.Session, pre func() (string, []interface{}, error), data interface{}, errLabel string) error {
	query, args, err := pre()
	if err != nil {
		return errx.New(errx.DB_QUERY, errx.WithMsgOption(errLabel), errx.WithErrorOption(err))
	}

	err = session.QueryRowsCtx(ctx, data, query, args...)
	if err == nil {
		return nil
	}
	if errors.Is(err, sqlx.ErrNotFound) {
		return errx.New(errx.DB_NOT_FOUND, errx.WithMsgOption(errLabel))
	}
	// Every other failure (connection, syntax, scan) must surface. Swallowing
	// it would make a failed query indistinguishable from an empty result.
	return errx.New(errx.DB_QUERY, errx.WithMsgOption(errLabel), errx.WithErrorOption(err))
}

// QueryRows is QueryRowsCtx run under context.Background().
func QueryRows(session sqlx.Session, pre func() (string, []interface{}, error), data interface{}, errLabel string) error {
	ctx := context.Background()
	return QueryRowsCtx(ctx, session, pre, data, errLabel)
}
